{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import cv2\n",
    "import torch\n",
    "from time import time\n",
    "import os\n",
    "import datetime\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Prefer the GPU; fall back to the CPU when CUDA is not available.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# Batch size (inference in this notebook handles one target frame at a time).\n",
    "batch_size = 1\n",
    "# The network regresses a (grid_h + 1) x (grid_w + 1) mesh of 2-D offsets.\n",
    "grid_h,grid_w = 15,15\n",
    "# Working resolution: frames are resized to W x H; both spellings are used below.\n",
    "H,W = height,width = 360,640"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_warp(net_out,img):\n",
    "    '''\n",
    "    Warp `img` with the mesh offsets predicted by the network.\n",
    "\n",
    "    A regular (grid_h + 1) x (grid_w + 1) lattice spanning [-1, 1] on both axes\n",
    "    is displaced by `net_out`, bilinearly upsampled to (height, width) and then\n",
    "    used as the sampling grid for `F.grid_sample`.\n",
    "\n",
    "    Inputs:\n",
    "        net_out: torch.Size([batch_size,grid_h +1 ,grid_w +1,2]), offsets in\n",
    "                 normalized coordinates with the last dim ordered (x, y)\n",
    "        img: image to warp, a 4-D tensor (batch, channels, rows, cols)\n",
    "    Returns:\n",
    "        the warped image, resampled to (height, width)\n",
    "    '''\n",
    "    # Rows follow grid_h and columns follow grid_w, so non-square meshes work too.\n",
    "    grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1),\n",
    "                                    torch.linspace(-1,1, grid_w + 1),\n",
    "                                    indexing='ij')\n",
    "    # grid_sample expects (x, y) in the last dimension.\n",
    "    src_grid = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0)\n",
    "    # One lattice per sample, placed on whatever device the prediction lives on.\n",
    "    src_grid = src_grid.repeat(net_out.shape[0],1,1,1).to(net_out.device)\n",
    "    new_grid = src_grid + net_out\n",
    "    # Channels-first for interpolation, then back to channels-last for sampling.\n",
    "    grid_upscaled = F.interpolate(new_grid.permute(0,3,1,2),size = (height,width), mode = 'bilinear',align_corners= True)\n",
    "    warped = F.grid_sample(img, grid_upscaled.permute(0,2,3,1),align_corners=True)\n",
    "    return warped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StabNet(nn.Module):\n",
    "    '''\n",
    "    VGG19 feature extractor followed by an MLP that regresses mesh offsets.\n",
    "\n",
    "    The first convolution is widened to 15 input channels (the inference loop\n",
    "    in this notebook feeds five stacked 3-channel frames). The output has\n",
    "    shape (batch, grid_h + 1, grid_w + 1, 2).\n",
    "\n",
    "    Args:\n",
    "        trainable_layers: controls how many trailing VGG19 parameter tensors\n",
    "            keep requires_grad=True (see the note in __init__).\n",
    "    '''\n",
    "    def __init__(self,trainable_layers = 10):\n",
    "        super(StabNet, self).__init__()\n",
    "        # Load the ImageNet-pretrained VGG19 model\n",
    "        vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')\n",
    "        # Extract the first conv layer's pretrained weights for 3-channel input\n",
    "        rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])\n",
    "        # Tile the kernels five times along the input-channel axis -> [64, 15, 3, 3]\n",
    "        tiled_rgb_weights = rgb_weights.repeat(1,5,1,1)\n",
    "        # Change the first layer from 3 to 15 input channels (created without a bias)\n",
    "        vgg19.features[0] = nn.Conv2d(15,64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        # set new weights\n",
    "        vgg19.features[0].weight = nn.Parameter(tiled_rgb_weights)\n",
    "        # Number of parameter tensors in the model (weights and biases count separately)\n",
    "        total_layers = sum(1 for _ in vgg19.parameters())\n",
    "        # Freeze every parameter tensor except the trailing ones.\n",
    "        # NOTE(review): with `>` only the last `trainable_layers - 1` tensors stay\n",
    "        # trainable, and that count includes the classifier tensors that are not\n",
    "        # part of the encoder built below. Behaviour left unchanged - confirm intent.\n",
    "        for idx, param in enumerate(vgg19.parameters()):\n",
    "            param.requires_grad = idx > total_layers - trainable_layers\n",
    "        # Keep the convolutional feature stack only, minus its last module\n",
    "        # (the final max-pool in torchvision's VGG19)\n",
    "        self.encoder = nn.Sequential(*vgg19.features[:-1])\n",
    "        self.regressor = nn.Sequential(nn.Linear(512,2048),\n",
    "                                       nn.ReLU(),\n",
    "                                       nn.Linear(2048,1024),\n",
    "                                       nn.ReLU(),\n",
    "                                       nn.Linear(1024,512),\n",
    "                                       nn.ReLU(),\n",
    "                                       nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))\n",
    "        #self.regressor[-1].bias.data.fill_(0)\n",
    "        encoder_params = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)\n",
    "        regressor_params = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)\n",
    "        # The messages below are kept verbatim so they match previously recorded output.\n",
    "        print(\"Total Trainable mobilenet Parameters: \", encoder_params)\n",
    "        print(\"Total Trainable regressor Parameters: \", regressor_params)\n",
    "        print(\"Total Trainable parameters:\",regressor_params + encoder_params)\n",
    "    \n",
    "    def forward(self, x_tensor):\n",
    "        '''\n",
    "        Predict mesh offsets for a batch of stacked frames.\n",
    "\n",
    "        Args:\n",
    "            x_tensor: input of shape (batch, 15, rows, cols)\n",
    "        Returns:\n",
    "            tensor of shape (batch, grid_h + 1, grid_w + 1, 2)\n",
    "        '''\n",
    "        x_batch_size = x_tensor.size()[0]\n",
    "        features = self.encoder(x_tensor)\n",
    "        # Global average pooling over the spatial dimensions -> (batch, 512)\n",
    "        pooled = torch.mean(features, dim=[2, 3])\n",
    "        offsets = self.regressor(pooled)\n",
    "        return offsets.view(x_batch_size,grid_h + 1,grid_w + 1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Trainable mobilenet Parameters:  2360320\n",
      "Total Trainable regressor Parameters:  3936256\n",
      "Total Trainable parameters: 6296576\n",
      "loaded weights ./ckpts/with_future_frames/stabnet_2023-11-02_23-03-39.pth\n"
     ]
    }
   ],
   "source": [
    "ckpt_dir = './ckpts/with_future_frames/'\n",
    "stabnet = StabNet().to(device).eval()\n",
    "\n",
    "def _ckpt_timestamp(name):\n",
    "    '''Return the datetime encoded in `<prefix>_<YYYY-MM-DD>_<HH-MM-SS>.<ext>`, or None if `name` does not match.'''\n",
    "    stem = os.path.splitext(name)[0]\n",
    "    stamp = '_'.join(stem.rsplit('_', 2)[-2:])\n",
    "    try:\n",
    "        return datetime.datetime.strptime(stamp, '%Y-%m-%d_%H-%M-%S')\n",
    "    except ValueError:\n",
    "        return None\n",
    "\n",
    "# Skip anything in the folder that is not a timestamped checkpoint.\n",
    "ckpts = [name for name in os.listdir(ckpt_dir) if _ckpt_timestamp(name) is not None]\n",
    "if ckpts:\n",
    "    # Order by the full date AND time; the time of day alone picks the wrong\n",
    "    # file as soon as checkpoints from different days sit side by side.\n",
    "    ckpts = sorted(ckpts, key=_ckpt_timestamp, reverse=True)\n",
    "    \n",
    "    # Get the filename of the latest checkpoint\n",
    "    latest = os.path.join(ckpt_dir, ckpts[0])\n",
    "\n",
    "    # NOTE: torch.load unpickles the file - only load checkpoints you trust.\n",
    "    # map_location lets a checkpoint saved on another device load onto `device`.\n",
    "    state = torch.load(latest, map_location=device)\n",
    "    stabnet.load_state_dict(state['model'])\n",
    "    print('loaded weights',latest)\n",
    "else:\n",
    "    print('no checkpoint found in', ckpt_dir, '- the model keeps its initial weights')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = 'E:/Datasets/DeepStab_Dataset/unstable/2.avi'\n",
    "cap = cv2.VideoCapture(path)\n",
    "if not cap.isOpened():\n",
    "    raise IOError(f'could not open video: {path}')\n",
    "# Per-channel normalisation constants.\n",
    "# NOTE(review): OpenCV decodes frames as BGR while these values are listed in the\n",
    "# usual RGB order - confirm this matches the preprocessing used for training.\n",
    "mean = np.array([0.485, 0.456, 0.406],dtype = np.float32) \n",
    "std = np.array([0.229, 0.224, 0.225],dtype = np.float32)\n",
    "frames = []\n",
    "try:\n",
    "    while True:\n",
    "        ret, img = cap.read()\n",
    "        if not ret: break\n",
    "        # Resize to the working resolution, scale to [0, 1], then normalise.\n",
    "        img = cv2.resize(img, (W,H))\n",
    "        img = (img / 255.0).astype(np.float32)\n",
    "        img = (img - mean)/std\n",
    "        frames.append(img)\n",
    "finally:\n",
    "    # Always hand the file / decoder back, even if decoding fails part-way.\n",
    "    cap.release()\n",
    "if not frames:\n",
    "    raise IOError(f'no frames could be read from: {path}')\n",
    "# (frame_count, H, W, 3) float32 array of normalised frames\n",
    "frames = np.array(frames,dtype = np.float32)\n",
    "frame_count,_,_,_ = frames.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "speed: 0.007669420583669505 seconds per frame\n"
     ]
    }
   ],
   "source": [
    "# Frames as (N, 3, H, W) on the CPU; the stabilised copy starts out identical, so\n",
    "# frames the loop never reaches keep their original content.\n",
    "frames_tensor = torch.from_numpy(frames).permute(0,3,1,2).float().to('cpu')\n",
    "stable_frames_tensor = frames_tensor.clone()\n",
    "\n",
    "# Temporal reach of the input window, in frames, on each side of the target frame.\n",
    "SKIP = 32\n",
    "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
    "def get_batch(idx):\n",
    "    '''Stack the five frames idx-SKIP, idx-SKIP//2, idx, idx+SKIP//2 and idx+SKIP along channels -> (1, 15, H, W) on `device`.'''\n",
    "    batch = torch.zeros((5,3,H,W)).float()\n",
    "    for i,j in enumerate(range(idx - SKIP, idx + SKIP + 1, SKIP//2)):\n",
    "        batch[i,...] = frames_tensor[j,...]\n",
    "    batch = batch.view(1,-1,H,W)\n",
    "    return batch.to(device)\n",
    "start = time()\n",
    "processed = 0\n",
    "try:\n",
    "    # Only frames with SKIP neighbours on both sides can be stabilised.\n",
    "    for frame_idx in range(SKIP,frame_count - SKIP):\n",
    "        batch = get_batch(frame_idx)\n",
    "        with torch.no_grad():\n",
    "            transform = stabnet(batch)\n",
    "            warped = get_warp(transform, frames_tensor[frame_idx: frame_idx + 1,...].to(device))\n",
    "        stable_frames_tensor[frame_idx] = warped[0].cpu()\n",
    "        processed += 1\n",
    "        # Undo the normalisation for the live preview.\n",
    "        img = warped.permute(0,2,3,1)[0,...].cpu().detach().numpy()\n",
    "        img *= std\n",
    "        img += mean\n",
    "        img = np.clip(img * 255.0,0,255).astype(np.uint8)\n",
    "        cv2.imshow('window', img)\n",
    "        if cv2.waitKey(1) & 0xFF == ord('q'):\n",
    "            break\n",
    "finally:\n",
    "    # Close the preview window even if inference raises.\n",
    "    cv2.destroyAllWindows()\n",
    "total = time() - start\n",
    "# Average over the frames actually processed, not over every frame of the clip.\n",
    "speed = total / max(processed, 1)\n",
    "print(f'speed: {speed} seconds per frame')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _to_uint8(normalized):\n",
    "    '''Undo the mean/std normalisation of an (N, C, H, W) CPU tensor and return (N, H, W, C) uint8 frames.'''\n",
    "    return np.clip(((normalized.permute(0,2,3,1).numpy() * std) + mean) * 255,0,255).astype(np.uint8)\n",
    "\n",
    "stable_frames = _to_uint8(stable_frames_tensor)\n",
    "# Stored under its own name: rebinding `frames` to uint8 data would make the\n",
    "# earlier `torch.from_numpy(frames)` cell read de-normalised frames on a re-run.\n",
    "unstable_frames = _to_uint8(frames_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from time import sleep\n",
    "# Write the stabilised frames to '2.avi' (30 fps, W x H) while previewing them.\n",
    "fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
    "out = cv2.VideoWriter('2.avi', fourcc, 30.0, (W,H))\n",
    "if not out.isOpened():\n",
    "    # Without this check a failed open silently produces no usable file.\n",
    "    raise IOError('could not open 2.avi for writing')\n",
    "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
    "try:\n",
    "    for idx in range(frame_count):\n",
    "        img = stable_frames[idx,...]\n",
    "        out.write(img)\n",
    "        cv2.imshow('window',img)\n",
    "        # Uncomment to throttle the preview to roughly real time:\n",
    "        #sleep(1/30)\n",
    "        # Space stops early; frames after this point are not written.\n",
    "        if cv2.waitKey(1) & 0xFF == ord(' '):\n",
    "            break\n",
    "finally:\n",
    "    # Finalise the file and close the window even if the loop raises.\n",
    "    out.release()\n",
    "    cv2.destroyAllWindows()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Frame: 446/447\n",
      "cropping score:0.996\tdistortion score:0.982\tstability:0.666\tpixel:0.997\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.9962942926265075, 0.9820633, 0.66624937150459, 0.9968942715786397)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the stabilised video ('2.avi', written by the export cell) against the\n",
    "# original unstable clip, using the `metric` helper imported from `metrics`.\n",
    "# The recorded run printed cropping, distortion, stability and pixel scores and\n",
    "# returned a 4-tuple of numbers.\n",
    "from metrics import metric\n",
    "metric('E:/Datasets/DeepStab_Dataset/unstable/2.avi','2.avi')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DUTCode",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
